
import torch.nn as nn

from models.resnet import resnet18
from models.vgg import vgg16
from models.vit_for_small_dataset import ViT

def get_model(args, image_size, num_classes):
    '''Set up the model selected by ``args.model_type``.

    Args:
        args: Namespace-like object; only ``args.model_type`` is read here.
            Supported values are ``'resnet18'``, ``'vgg16'`` and ``'vit'``.
        image_size: Side length of the square input images. Only the ViT
            branch uses it, as ``(image_size, image_size)``.
        num_classes: Number of output classes passed to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``args.model_type`` is not one of the supported values.
    '''
    model_type = args.model_type
    if model_type == 'resnet18':
        return resnet18(num_classes=num_classes)
    if model_type == 'vgg16':
        return vgg16(num_classes=num_classes)
    if model_type == 'vit':
        # NOTE(review): image_size presumably must be divisible by the 4x4
        # patch size for this ViT -- confirm against models.vit_for_small_dataset.
        return ViT(
            image_size=(image_size, image_size),
            patch_size=(4, 4),
            num_classes=num_classes,
            dim=512,
            mlp_dim=1024,
            dim_head=64,
            depth=6,
            heads=12,
            dropout=0.1,
            emb_dropout=0.1,
        )
    # An unknown type used to fall through to `return model` and fail with an
    # UnboundLocalError; report the actual problem instead.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; "
        "expected one of 'resnet18', 'vgg16', 'vit'"
    )

def attach_hook(args, model, hook_fn):
    '''Register ``hook_fn`` as a forward hook on the model's pooling/latent layer.

    The hooked layer depends on ``args.model_type``:
    ``model.avgpool`` for ``'resnet18'`` and ``'vgg16'``, and
    ``model.to_latent`` for ``'vit'``.

    Args:
        args: Namespace-like object; only ``args.model_type`` is read here.
        model: Model matching ``args.model_type`` (e.g. as built by ``get_model``).
        hook_fn: Callable passed unchanged to ``register_forward_hook``.

    Returns:
        The handle returned by ``register_forward_hook``. Call ``.remove()``
        on it to detach the hook. The handle used to be discarded, so the
        hook could never be removed.

    Raises:
        ValueError: If ``args.model_type`` is not one of the supported values.
    '''
    model_type = args.model_type
    if model_type in ('resnet18', 'vgg16'):
        # Both architectures expose the layer under the same attribute name.
        layer = model.avgpool
    elif model_type == 'vit':
        layer = model.to_latent
    else:
        # Previously this case silently attached nothing, so downstream code
        # would wait for hook output that never arrives.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; "
            "expected one of 'resnet18', 'vgg16', 'vit'"
        )
    return layer.register_forward_hook(hook_fn)
        
        
        
